import copy
import json
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from torch.nn import BCEWithLogitsLoss
import pandas as pd
import numpy as np
import argparse
from sklearn.metrics import classification_report
from tqdm import tqdm
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
import torch.nn.functional as F
import torch.nn as nn



# Closed set of category label strings. The list order is significant: it is
# passed as ``classes=`` to the MultiLabelBinarizer (so it fixes the column
# order of the multi-hot label vectors) and it is indexed by predicted column
# position when predictions are turned back into label names.
# NOTE(review): 'Cayegory-Social Networking' looks like a misspelling of
# 'Category-Social Networking'. It is left untouched because these strings
# have to match the labels present in the data — confirm against the label
# source before changing it.
categories = ['Category-Weather', 'Category-Games', 'Category-Finance', 'Category-Travel', 'Category-Reference', 'Category-Widgets', 'Category-Utilities', 'Category-Medical', 'Category-Navigation', 'Category-Productivity',
              'Category-Snippets', 'Category-Sports', 'Category-Books', 'Cayegory-Social Networking', 'Category-Lifestyle', 'Category-Education', 'Category-Graphics & Design', 'Category-Shopping', 'Category-News', 'Category-Jailbreaking',
              'Category-Kids', 'Category-Entertainment', 'Category-Photo & Video', 'Category-Business', 'Category-Health & Fitness', 'Category-Music', 'Category-Development Tools', 'Category-Food & Drink']

class MultiLabelDataset(Dataset):
    """Torch dataset that tokenizes shortcut records and exposes multi-hot labels.

    Each record is a dict with at least a ``'response'`` string. If the first
    record carries a ``'label'`` key, every record is expected to have one and
    the labels are binarized against the module-level ``categories`` list.
    Otherwise (including when ``data`` is empty) an all-zero label matrix is
    used as a placeholder so the item layout stays the same.

    Args:
        data: Sequence of record dicts.
        tokenizer: Tokenizer object providing ``encode_plus``.
        max_length: Sequences are padded/truncated to this many tokens.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlb = MultiLabelBinarizer(classes=categories)
        # Guard the data[0] probe so an empty dataset does not raise IndexError.
        if len(data) > 0 and 'label' in data[0]:
            self.labels = self.mlb.fit_transform([item['label'] for item in data])
        else:
            # One independent zero row per record (a nested-list multiplication
            # would make every row an alias of the same inner list).
            self.labels = np.zeros((len(data), len(categories)), dtype=int)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized tensors and the float label vector for record ``idx``."""
        item = self.data[idx]
        # Strip surrounding braces and remove the field-name tokens from the
        # serialized response so only the free text is fed to the tokenizer.
        text = (
            item['response']
            .strip('{}')
            .replace('query', '')
            .replace('step_by_step_description', '')
            .replace("'':", '')
        )
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            return_token_type_ids=True,
            return_attention_mask=True,
            truncation=True
        )

        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(inputs["token_type_ids"], dtype=torch.long),
            'labels': torch.tensor(self.labels[idx], dtype=torch.float)
        }
class ATLoss(nn.Module):
    """Adaptive-threshold loss for multi-label classification.

    Column 0 of the logits is treated as a threshold ("TH") class. The loss
    pushes the logits of positive classes above the TH logit and the TH logit
    above the logits of negative classes; at prediction time a class is
    selected when its logit exceeds the TH logit of the same row.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        """Compute the mean adaptive-threshold loss over the batch.

        Args:
            logits: Tensor of shape (batch, num_classes + 1); column 0 is TH.
            labels: Multi-hot tensor of shape (batch, num_classes), without
                the TH column.

        Returns:
            Scalar tensor: the batch mean of the two ranking terms.
        """
        # TH label: one-hot on column 0, same dtype/device as ``labels``.
        th_label = torch.zeros((labels.shape[0], labels.shape[1] + 1), dtype=torch.float).to(labels)
        # Prepend a zero TH column so ``labels`` lines up with ``logits``.
        # torch.cat builds a new tensor, so the caller's labels are not mutated.
        labels = torch.cat((torch.zeros(labels.shape[0], 1, dtype=th_label.dtype).to(labels), labels), dim=1)

        th_label[:, 0] = 1.0
        labels[:, 0] = 0.0

        p_mask = labels + th_label  # positives plus TH
        n_mask = 1 - labels         # negatives plus TH

        # Rank positive classes to TH: softmax restricted to positives + TH
        # (everything else is pushed to a very negative logit).
        logit1 = logits - (1 - p_mask) * 1e30
        loss1 = -(F.log_softmax(logit1, dim=-1) * labels).sum(1)

        # Rank TH to negative classes: softmax restricted to negatives + TH.
        logit2 = logits - (1 - n_mask) * 1e30
        loss2 = -(F.log_softmax(logit2, dim=-1) * th_label).sum(1)

        # Sum two parts
        loss = loss1 + loss2
        return loss.mean()

    def get_label(self, logits, num_labels=8):
        """Turn logits into multi-hot predictions.

        A class is selected when its logit is above the row's TH logit. The
        highest-scoring non-TH class is always selected, so every row gets at
        least one label. When ``num_labels > 0`` a selected class is kept only
        if its logit is among the ``num_labels`` largest values of the row
        (the TH column takes part in that ranking).

        Args:
            logits: Tensor of shape (batch, num_classes + 1); column 0 is TH.
            num_labels: Cap described above; values larger than the number of
                columns are clamped, ``<= 0`` disables the cap.

        Returns:
            NumPy array of shape (batch, num_classes) holding 0./1. values,
            with the TH column removed.
        """
        th_logit = logits[:, 0].unsqueeze(1)
        output = torch.zeros_like(logits)
        mask = logits > th_logit

        # Force the best real class on, even when nothing beats the threshold.
        max_indices = logits[:, 1:].argmax(dim=1) + 1
        row_indices = torch.arange(logits.size(0), device=logits.device)
        mask[row_indices, max_indices] = True

        if num_labels > 0:
            # topk raises when k exceeds the dimension size, so clamp it.
            k = min(num_labels, logits.size(1))
            top_v, _ = torch.topk(logits, k, dim=1)
            kth_value = top_v[:, -1]
            mask = (logits >= kth_value.unsqueeze(1)) & mask

        output[mask] = 1.0
        # TH column is set only for rows with no selected class; it is dropped
        # from the returned array anyway.
        output[:, 0] = (output.sum(1) == 0.).to(logits)
        return output[:, 1:].cpu().numpy()


def evaluate_model(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` and print loss and per-category metrics.

    Relies on the module-level ``device``, ``loss_function`` and ``categories``.

    Args:
        model: Sequence-classification model whose output exposes ``.logits``.
        val_loader: Iterable of batches with 'ids', 'mask', 'token_type_ids'
            and 'labels' tensors.

    Returns:
        The weighted-average F1 score taken from the classification report.
    """
    model.eval()
    loss_sum = 0
    gold_batches = []
    pred_batches = []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['ids'].to(device, dtype=torch.long)
            attention = batch['mask'].to(device, dtype=torch.long)
            segment_ids = batch['token_type_ids'].to(device, dtype=torch.long)
            targets = batch['labels'].to(device, dtype=torch.float)

            logits = model(input_ids, attention_mask=attention, token_type_ids=segment_ids).logits
            loss_sum += loss_function(logits, targets).item()

            gold_batches.append(targets.cpu().numpy())
            pred_batches.append(loss_function.get_label(logits).astype(int))

    report = classification_report(
        np.vstack(gold_batches),
        np.vstack(pred_batches),
        target_names=categories,
        zero_division=0,
        output_dict=True
    )

    print(f'Validation Loss: {loss_sum / len(val_loader)}')
    print('F1 Score per Category:')
    for name, metrics in report.items():
        # The report also contains aggregate rows; only print real categories.
        if name not in categories:
            continue
        rendered = ''.join(f"{k}:{round(v, 4)} " for k, v in metrics.items())
        print(f"{name}: {rendered}")

    weighted_f1 = report['weighted avg']['f1-score']
    print(f"Overall F1 Score: {weighted_f1:.4f}")
    return weighted_f1


def report_model(model, test_loader, unlabeled_shortcut):
    """Predict categories for unlabeled shortcuts and store them under 'label'.

    Relies on the module-level ``device``, ``loss_function`` and ``categories``.
    Predictions are matched to records by position, so ``test_loader`` must
    iterate over the same records, in the same order, as ``unlabeled_shortcut``.

    Args:
        model: Sequence-classification model whose output exposes ``.logits``.
        test_loader: Iterable of batches with 'ids', 'mask' and
            'token_type_ids' tensors.
        unlabeled_shortcut: List of record dicts; modified in place.

    Returns:
        The same ``unlabeled_shortcut`` list, each record now carrying a
        ``'label'`` list of predicted category names.
    """
    model.eval()
    pred_batches = []

    with torch.no_grad():
        for batch in test_loader:
            ids = batch['ids'].to(device, dtype=torch.long)
            mask = batch['mask'].to(device, dtype=torch.long)
            token_type_ids = batch['token_type_ids'].to(device, dtype=torch.long)

            logits = model(ids, attention_mask=mask, token_type_ids=token_type_ids).logits
            pred_batches.append(loss_function.get_label(logits).astype(int))

    # np.vstack raises on an empty list; fall back to an empty prediction
    # matrix so a loader that yields no batches is handled gracefully.
    if pred_batches:
        all_pred_labels = np.vstack(pred_batches)
    else:
        all_pred_labels = np.zeros((0, len(categories)), dtype=int)

    for i, shortcut in enumerate(unlabeled_shortcut):
        predicted_idxs = np.nonzero(all_pred_labels[i, :])[0]
        shortcut['label'] = [categories[idx] for idx in predicted_idxs]

    return unlabeled_shortcut


def train_and_evaluate(model, train_loader, val_loader, epochs=10):
    """Train ``model`` and checkpoint it whenever the validation F1 improves.

    Relies on the module-level ``optimizer``, ``scheduler``, ``device``,
    ``loss_function`` and ``MAX_LEN``. The scheduler is stepped once per batch.

    Args:
        model: Sequence-classification model whose output exposes ``.logits``.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches passed to ``evaluate_model``.
        epochs: Number of passes over ``train_loader``.
    """
    best_score = 0
    for epoch_number in range(1, epochs + 1):
        model.train()
        running_loss = 0

        for batch in tqdm(train_loader, mininterval=30):
            optimizer.zero_grad()

            input_ids = batch['ids'].to(device, dtype=torch.long)
            attention = batch['mask'].to(device, dtype=torch.long)
            segment_ids = batch['token_type_ids'].to(device, dtype=torch.long)
            targets = batch['labels'].to(device, dtype=torch.float)

            logits = model(input_ids, attention_mask=attention, token_type_ids=segment_ids).logits
            batch_loss = loss_function(logits, targets)
            running_loss += batch_loss.item()

            batch_loss.backward()
            optimizer.step()
            scheduler.step()

        mean_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch_number}, Train Loss: {mean_loss}')

        score = evaluate_model(model, val_loader)
        # Strict comparison: a tie with the best score does not overwrite the checkpoint.
        if score > best_score:
            best_score = score
            torch.save(model.state_dict(), f'./deberta_{MAX_LEN}_best_classify.pkl')

# Load every shortcut record. JSON text is UTF-8, so the encoding is stated
# explicitly instead of depending on the platform default.
with open('./query_and_description_2.json', 'r', encoding='utf-8') as fp:
    ALL_shortcut_details = json.load(fp)

# Mapping from shortcut key to its list of category labels.
# NOTE(security): pickle.load can execute arbitrary code contained in the
# file; only load this pickle from a trusted source.
with open('./Routinehub_shortcut2category.pkl', 'rb') as fp:
    Routinehub_shortcut2category = pickle.load(fp)

labeled_shortcut = []
unlabeled_shortcut = []

# Records whose key has a known category become training data (and get a
# 'label' field attached); the rest are kept for prediction.
for shorcut in ALL_shortcut_details:
    if shorcut['key'] in Routinehub_shortcut2category:
        shorcut['label'] = Routinehub_shortcut2category[shorcut['key']]
        labeled_shortcut.append(shorcut)
    else:
        unlabeled_shortcut.append(shorcut)


# Command-line options:
#   --load_path  checkpoint to load; when left empty the script trains instead
#   --device     CUDA device index
#   --MAX_LEN    token length used for padding/truncation
parser = argparse.ArgumentParser(description="classifying shortcuts' categories.")
parser.add_argument('--load_path', type=str, default="")
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--MAX_LEN', type=int, default=128)

args = parser.parse_args()

# Local directory of the pretrained checkpoint used for both tokenizer and model.
model_path = "/Pretrained_Language_Models/deberta-v3-base"

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Token length for padding/truncation; also embedded in the checkpoint file name.
MAX_LEN =args.MAX_LEN

BATCH_SIZE = 16
# The scheduler's total step count below is derived from Epochs, so the number
# of epochs actually trained should match this value.
Epochs = 5
labeled_dataset = MultiLabelDataset(labeled_shortcut, tokenizer, max_length=MAX_LEN)
# 80/20 train/validation split with a fixed seed.
# NOTE(review): train_test_split presumably indexes the dataset item by item,
# so train_data/val_data would hold already-tokenized samples — TODO confirm.
train_data, val_data = train_test_split(labeled_dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)


unlabeled_dataset = MultiLabelDataset(unlabeled_shortcut, tokenizer, max_length=MAX_LEN)

# Not shuffled: predictions are matched back to unlabeled_shortcut by position.
test_loader = DataLoader(unlabeled_dataset, batch_size=4, shuffle=False, drop_last=False)

# One extra output on top of the real categories: ATLoss uses logit column 0
# as its threshold class.
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=len(categories) + 1)

# Parameters whose name contains one of these substrings get their own,
# higher learning rate (1e-4); all others use the optimizer default (3e-5).
new_layer = ["pooler", "classifier"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in new_layer)], },
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in new_layer)], "lr": 1e-4},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5)
# The scheduler is stepped once per batch, so the schedule length is
# batches-per-epoch times Epochs; the first 6% of the steps are warmup.
total_steps = int(len(train_loader) * Epochs )
print('total_steps:', total_steps)
warmup_steps = int(total_steps * 0.06)
print('warmup_steps:', warmup_steps)

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

loss_function = ATLoss()


# Use the requested CUDA device when CUDA is available, otherwise the CPU.
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
model.to(device)

if not args.load_path:
    # Train for exactly the number of epochs the LR schedule was sized for.
    # The linear schedule decays the learning rate to 0 after total_steps
    # (= batches per epoch * Epochs), so any further epochs would run with a
    # zero learning rate and leave the weights unchanged.
    train_and_evaluate(model, train_loader, val_loader, Epochs)

else:
    # map_location lets a checkpoint saved on one device (e.g. cuda:0) be
    # loaded on whichever device was selected for this run.
    model.load_state_dict(torch.load(args.load_path, map_location=device))
    evaluate_model(model, val_loader)
    report_model(model, test_loader, unlabeled_shortcut)

ALL_shortcut_details = labeled_shortcut + unlabeled_shortcut

# NOTE(review): in training mode no predictions are made, so the unlabeled
# records are written without a 'label' field; predicted labels are present
# only when the script is run with --load_path.
with open('./deberta_predicted_query_and_description.json', 'w') as fp:
    json.dump(ALL_shortcut_details, fp)

# Report success only after the file has been flushed and closed.
print('dump done.')
